#!/usr/bin/env python

from optparse import OptionParser
import pygtk
pygtk.require("2.0") 
import gobject,gtk,gtk.glade
import sys,pickle,os,glob
import scipy
from pylab import *
import sqlite3

import ocrolib
from ocrolib import dbtables,common
from ocrolib import Record

def TODO():
    print """
    accept-selection button (turns everything else to _)
    redo-selection button (turns selection to _)
    reject-selection button (turns selection to #)
    misseg-selection button (turns selection to ~)
    cluster-classifier using output of manual classification
    special chars: #=reject, _=unclassified, ~=missegmentation
    put confidence into database and display
    nnet+exception
    centering on largest component only
    sort by size, similarity, frequency
    """

# Command-line interface.  The usage string below is what optparse prints
# for --help, so its text is kept verbatim.
parser = OptionParser(usage="""
%prog [options] [input.db]

Edit character- and cluster-databases.

OCRopus isolated character training data is stored in sqlite3 databases
with a simple structure.  Isolated characters are (usually) stored in
a table called "chars", while clusters are stored in a table called
"clusters".  This application lets you browse and label these characters
and clusters.

You need to specify the correct table name for your database using the "-t"
flag.

Updates to the database are transacted and happen immediately; you don't need
to save.  Even concurrent editing and viewing is possible.  There is no
undo, but it's usually easy to manually undo changes.

There is a variety of sorting, training, and classification options available.
Particularly useful is the ability to sort by size and aspect ratio to quickly 
identify characters that have implausible sizes or aspect ratios.  Once you have
a character model, you can also sort by classifier confidence and focus correction
efforts on the characters with the least confidence.

Eventually, you will be able to cluster and train right from this interface, but
for now, use the external ocropus-ctrain and ocropus-cluster-db programs.

""")
# -m: model file name; options.model is the default argument of
#     load_classifier() and load_confidence() defined later in this file.
parser.add_option("-m","--model",help="model used for classification",default="default.cmodel")
# -v: boolean flag stored as options.verbose.
parser.add_option("-v","--verbose",help="verbose",action="store_true")
# -t: table to edit; when left as None the script picks the first table
#     it finds in the database.
parser.add_option("-t","--table",help="which table to edit",default=None)
# use this with old (unflipped) cmodels trained from Python
# NOTE(review): options.flip defaults to 1 and the action is "store_false",
# so passing -F turns flipping OFF even though the help text reads as if it
# turned it on.
parser.add_option("-F","--flip",help="flip characters before handing to classifier",default=1,action="store_false")
# -B: integer limit consulted by set_store() when filling the icon view.
parser.add_option("-B","--batchsize",help="maximum number of characters loaded for display",type="int",default=50000)
# Parsing happens at import time; `args` holds the positional arguments
# (the optional database file name).
(options,args) = parser.parse_args()

# ion() comes from `from pylab import *`; it switches matplotlib to
# interactive mode.
ion()

### misc utility functions

def detuple(item):
    """Unwrap nested tuples/lists down to their first scalar element.

    As long as `item` is exactly a tuple or a list (subclasses are not
    unwrapped), it is replaced by its first element; the first value that
    is neither is returned.  An empty tuple/list raises IndexError.
    """
    while type(item) in (tuple,list):
        item = item[0]
    return item

def numpy2pixbuf(a,limit=40):
    """Convert a numpy array to a gray-level gtk pixbuf.

    a     -- 2-D array; each value is multiplied by 255 and stored as an
             unsigned byte in all three color channels (presumably the
             values are expected to lie in 0..1 -- confirm with callers)
    limit -- maximum side length; if the larger dimension of `a` exceeds
             it, the image is zoomed down (linear interpolation) so that
             this dimension becomes about `limit` pixels

    Images that had to be scaled down are flagged with a small green
    square in the top-left corner.  Returns a gtk.gdk.Pixbuf.
    """
    # The file only does `import scipy`, which by itself does not guarantee
    # that the ndimage subpackage is loaded; import it explicitly here
    # instead of relying on some other module having done so.
    import scipy.ndimage
    r = max(a.shape)
    scaled = r>limit
    if scaled:
        a = array(a,'f')
        a = scipy.ndimage.interpolation.zoom(a,limit/float(r),order=1)
    gray = 255*a
    data = zeros(list(a.shape)+[3],'B')
    for channel in range(3):
        data[:,:,channel] = gray
    if scaled:
        # 3x3 marker: clear all channels, then set only green
        data[:3,:3,:] = 0
        data[:3,:3,1] = 255
    return gtk.gdk.pixbuf_new_from_array(data,gtk.gdk.COLORSPACE_RGB,8)

### load the cluster images

print "opening table"
file = "clusters.db"
if len(args)>0: file = args[0]
if not os.path.exists(file):
    print file,"not found"
    sys.exit(1)
if options.table is None:
    db = sqlite3.connect(file)
    cur = db.cursor()
    names = cur.execute("select name from sqlite_master where type='table'")
    options.table = list(names)[0][0]
    print "defaulting to table",options.table
table = dbtables.Table(file,options.table)
table.converter("image",dbtables.SmallImage())
print "done"

# Work out the extent of the cluster grid.  Each distinct "grid" value is a
# whitespace-separated string whose first two fields are used as the row
# and column index of a cell.
# NOTE: this relies on Python 2's map() returning a list (x[0]/x[1] below).
# NOTE(review): the table name is interpolated into the SQL text; it comes
# from -t or from sqlite_master above.
grid = [row[0] for row in table.query("select distinct(grid) from '%s'" % options.table)]
grid = [map(int,x.split()) for x in grid]
# grid dimensions = largest index seen in each direction, plus one
rows = amax(array([x[0] for x in grid]))+1
cols = amax(array([x[1] for x in grid]))+1
print "rows",rows,"cols",cols
# rows x cols object arrays, pre-filled with None; cells for which the
# table has no record stay None.
data = array([[None]*cols]*rows)
classes = array([[None]*cols]*rows)
for row in table.get():
    # here the grid string must contain exactly two integers
    r,c = map(int,row.grid.split())
    # presumably rescales 0..255 pixel values to 0..1 -- TODO confirm
    # against what dbtables.SmallImage() returns
    data[r,c] = row.image/255.0
    classes[r,c] = row.cls
# NOTE: the name `grid` is rebound to a gtk.ListStore by set_store().

def get_clusters():
    """Return one Record per grid cell, in row-major order.

    Each record carries the cell's image and class label taken from the
    module-level `data` and `classes` arrays, an empty `classes` string
    and a `count` of 1.
    """
    nrows,ncols = data.shape[0],data.shape[1]
    return [Record(image=data[i,j],cls=classes[i,j],classes="",count=1)
            for i in range(nrows)
            for j in range(ncols)]

def get_classes():
    """Return the table's "cls" keys as a list."""
    keys = table.get_keys("cls")
    return list(keys)

def compute_combolist():
    """Rebuild the class selector's model from the classes in the table.

    Sets the globals `charlist` ("_" followed by the sorted result of
    get_classes()) and `combolist` (a one-column gtk.ListStore with the
    same entries), installs the store in the class selector and, if the
    previously active text is still in the list, re-selects it.
    """
    global charlist,combolist
    # remember the current choice before the model is replaced
    previous = class_selector.get_active_text()
    charlist = ["_"]+sorted(get_classes())
    combolist = gtk.ListStore(str)
    for label in charlist:
        combolist.append([label])
    class_selector.set_model(combolist)
    class_selector.set_text_column(0)
    if previous in charlist:
        class_selector.set_active(charlist.index(previous))

def set_store(sortfun=None):
    """Fill the icon view with the clusters returned by get_clusters().

    sortfun -- optional callable that takes the list of cluster records
               and returns them in the order they should be displayed

    Rebinds the global `grid` to a new gtk.ListStore with the columns
    (pixbuf, class label, cluster record), installs it as the model of
    `cluster_viewer` and moves the cursor to the first item.  At most
    options.batchsize items are loaded.
    """
    global grid
    grid = gtk.ListStore(gtk.gdk.Pixbuf,
                         str,
                         gobject.TYPE_PYOBJECT)
    selected = get_clusters()
    if sortfun is not None:
        selected = sortfun(selected)
    for rownum,cluster in enumerate(selected):
        # Check the limit before appending; testing it after the append
        # (as was done previously) lets batchsize+1 items through.
        if rownum>=options.batchsize: break
        # the image is inverted (1.0-image) for display
        pixbuf = numpy2pixbuf(1.0-cluster.image)
        grid.append([pixbuf,cluster.cls,cluster])
    cluster_viewer.set_model(grid)
    # nothing to put the cursor on if the store ended up empty
    if len(grid)>0:
        move_to(0)

def move_to(index):
    """Put the cursor on the given item, select it and refresh the info line.

    `index` may be a plain value or a nested tuple/list; it is reduced to
    its first scalar element with detuple() before being handed to the view.
    """
    target = detuple(index)
    cluster_viewer.set_cursor(target)
    cluster_viewer.select_path(target)
    update_info()

def update_info():
    """Show "<classes> <count>" of the item under the cursor in the info area.

    Does nothing when the view reports no cursor.
    """
    cursor = cluster_viewer.get_cursor()
    if cursor is None: return
    # cursor[0] is the path; its first component indexes the store
    entry = grid[cursor[0][0]]
    info_area.set_text(str(entry[2].classes)+" "+str(entry[2].count))

def get_extended():
    """Ask for a transcript in a modal dialog and return the entered text.

    Used for extended labels.  The dialog is a message dialog with an OK
    button and an added text entry; pressing Enter in the entry makes the
    dialog respond with RESPONSE_OK.  The entry's text is returned no
    matter how the dialog was closed.
    """
    flags = gtk.DIALOG_MODAL|gtk.DIALOG_DESTROY_WITH_PARENT
    prompt = gtk.MessageDialog(None,flags,gtk.MESSAGE_QUESTION,gtk.BUTTONS_OK,None)
    prompt.set_markup("Transcript:")
    field = gtk.Entry()

    def respond(_entry,target,response):
        # Enter in the text field behaves like pressing the dialog button
        target.response(response)

    field.connect("activate",respond,prompt,gtk.RESPONSE_OK)
    prompt.vbox.pack_end(field,True,True,0)
    prompt.show_all()
    prompt.run()
    transcript = field.get_text()
    prompt.destroy()
    return transcript

def set_dist(x,s):
    """Return the smallest common.symdist() error between x and members of s.

    x -- an ndarray
    s -- non-empty sequence whose first element must be an ndarray

    Members of `s` whose shape differs from that of `x` are skipped.  If
    no member is comparable, the initial value 10000 is returned.
    """
    assert type(x)==ndarray,x
    assert type(s[0])==ndarray,s[0]
    best = 10000
    for candidate in s:
        if x.shape==candidate.shape:
            # only the first of the three returned values is used
            err,_rerr,_rest = common.symdist(x,candidate)
            best = min(err,best)
    return best

### toolbar commands

def cmd_similar(*args):
    """Sort the view by similarity to the selected item(s).

    Toolbar "clicked" handler; the arguments passed by GTK are ignored.
    Every image is normalized with ocrolib.stdsize() and compared against
    the (equally normalized) selected images using set_dist(); the store
    is then reordered by increasing distance.  With nothing selected
    there is no reference to compare against and the order is left
    unchanged.  Always returns 1.
    """
    assert type(grid)==gtk.ListStore
    paths = cluster_viewer.get_selected_items()
    if not paths:
        # set_dist() indexes s[0] and would raise IndexError on an
        # empty reference set
        return 1
    selected = [ocrolib.stdsize(grid[i][2].image) for i in paths]
    dists = [set_dist(ocrolib.stdsize(x[2].image),selected) for x in grid]
    # reorder() wants plain Python ints, not numpy integers
    index = [int(i) for i in argsort(dists)]
    grid.reorder(index)
    cluster_viewer.unselect_all()
    # NOTE(review): this scrolls to path 1 rather than 0 -- confirm that
    # this is intended
    cluster_viewer.scroll_to_path(1,1,0,0)
    return 1

# sorting

def sort_menu_populate(cmd):
    """Build the radio menu listing the available sort criteria.

    cmd -- "toggled" handler; it is connected to every item with the
           criterion name as extra argument

    No criterion is treated as a default, so the last item is activated
    once the menu is complete.  Returns the gtk.Menu.
    """
    criteria = ("count","width","height","aspect","area",
                "components","holes","endpoints","junctions")
    menu = gtk.Menu()
    previous = None
    for name in criteria:
        entry = gtk.RadioMenuItem(group=previous,label=name)
        entry.connect("toggled",cmd,name)
        menu.append(entry)
        previous = entry
    # activate the last one since there is no default
    previous.activate()
    return menu

def cmd_sel_sort(item,which):
    """Placeholder "toggled" handler for the sort menu items; does nothing."""
    return None

def cmd_sort(*args):
    """Placeholder "clicked" handler for the Sort button; does nothing."""
    return None

# classifying

def classifier_menu(cmd):
    """Build a radio menu listing the available *.cmodel files.

    cmd -- "toggled" handler; it is connected to every item with the
           model path as extra argument

    Models are collected from the current directory and, if
    ocrolib.finddir("models") succeeds, from that directory as well.
    Items whose path contains "default" are activated as they are
    created (before `cmd` is connected to them); if there is no such
    item, the last item is activated instead.  When no model file is
    found at all, the returned menu is simply empty.
    """
    menu = gtk.Menu()
    group = None
    item = None
    activated = 0
    models = glob.glob("*.cmodel")
    try:
        modeldir = ocrolib.finddir("models")
        models += glob.glob(modeldir+"/*.cmodel")
    except IOError:
        # best effort: without a models directory only local files are listed
        pass
    for model in models:
        item = gtk.RadioMenuItem(group=group,label=model)
        if "default" in model:
            item.activate()
            activated = 1
        group = item
        item.connect("toggled",cmd,model)
        menu.append(item)
    # `item` is still None when no model was found; without this check the
    # fallback activation below failed on an unbound name.
    if item is not None and not activated:
        item.activate() # activate the last one if there is no default
    return menu

def cmd_sel_class(item,which):
    if item.get_active():
        print "activated class",item,which
        load_classifier(which)

# the currently loaded classifier (nothing in this file assigns it yet)
classifier = None

def load_classifier(which=options.model):
    """Placeholder for loading the classifier model `which`; does nothing.

    The default is the value of options.model at definition time.
    """
    return None

load_classifier()

# confidence 

# the currently loaded confidence model (nothing in this file assigns it yet)
confidence = None

def load_confidence(which=options.model):
    """Placeholder for loading the confidence model `which`; does nothing.

    The default is the value of options.model at definition time.
    """
    return None
load_confidence()

def cmd_sel_conf(item,which):
    if item.get_active():
        print "activated conf",item,which
        load_confidence(which)
def cmd_nn(*args):
    """Placeholder "clicked" handler for the Classify button; does nothing."""
    return None
def cmd_class(*args):
    """Placeholder "clicked" handler (build_toolbar connects it to the
    Confidence button); ignores its arguments and returns 1."""
    return 1

def cmd_train(*args):
    """Placeholder "clicked" handler for the Train button; ignores its
    arguments and returns 1."""
    return 1

### basic event handlers

def on_iconview1_item_activated(*args):
    """Refresh the info line; main() registers this for both the
    item-activated and the selection-changed signal of the icon view."""
    update_info()

def on_iconview1_motion_notify_event(widget,event):
    """Pointer-motion handler of the icon view; currently a no-op that
    returns 0."""
    # Disabled: make the cursor follow the item under the pointer.
    # item = widget.get_item_at_pos(event.x,event.y)
    # if item is not None: move_to(item[0])
    return 0

def on_iconview1_button_press_event(widget,event):
    """Mouse-button handler of the icon view.

    Button 1: refresh the info line and return 0.
    Button 2: if there is an item under the pointer, move to it, reset its
              label to "_" (in the view and in the record) and write the
              record back to the table; returns None.
    Any other button: do nothing and return None.
    """
    button = event.button
    if button==1:
        update_info()
        return 0
    if button!=2:
        return
    hit = widget.get_item_at_pos(event.x,event.y)
    if hit is None:
        return
    label = "_"
    index = detuple(hit)
    move_to(index)
    entry = grid[index]
    entry[1] = label
    entry[2].cls = label
    table.put(entry[2])

def on_iconview1_key_press_event(widget,event):
    """Keyboard handler of the icon view.

    ^W         select the previous entry of the class selector
    ^Z         select the next entry of the class selector
    ^R         rebuild the class selector from the table
    ^D         label the selection with the empty string
    space      label the selection with "~"
    "."        ask for a label with get_extended()
    any other key whose string compares >= " "
               label the selection with that string

    Labels are applied to all selected items, or to the item under the
    cursor when nothing is selected, and written to the table in a single
    commit; afterwards the selection is cleared and the cursor moves to
    the item following the last labeled one.

    Returns 1 when the key was handled and 0 otherwise.
    """
    key = event.string
    if key=="\027": # ^W
        i = class_selector.get_active()
        if i<=0: return 1
        class_selector.set_active(i-1)
        update_info()
        return 1
    if key=="\032": # ^Z
        i = class_selector.get_active()
        if i<0: i = 0
        if i>=len(combolist)-1: return 1  # already on the last entry
        class_selector.set_active(i+1)
        update_info()
        return 1
    if key=="\022": # ^R = reload
        compute_combolist()
        return 1
    if key>=" " or key=="\004":
        if key==".":
            s = get_extended()
        elif key=="\004":
            s = ""
        elif key==" ":
            s = "~"
        else:
            s = key
        items = cluster_viewer.get_selected_items()
        if not items:
            cursor = widget.get_cursor()
            if cursor is None:
                # neither a selection nor a cursor: nothing to label
                # (same guard as in update_info)
                return 1
            items = [cursor[0][0]]
        last = -1
        for item in items:
            item = detuple(item)
            if item>last: last = item
            row = grid[item]
            row[1] = s
            row[2].cls = s
            # defer the commit so that all items go into one transaction
            table.put(row[2],commit=0)
        table.commit()
        cluster_viewer.unselect_all()
        move_to(last+1)
        return 1
    return 0

def build_toolbar():
    """Populate the "toolbar" widget of the glade tree.

    Appends, in this order: a "Similar" button, the "Sort", "Classify"
    and "Confidence" drop-down buttons and a "Train" button.  The global
    `toolbar` is set to the toolbar widget.
    """
    global toolbar
    toolbar = main_widget_tree.get_widget("toolbar")
    toolbar.set_style(gtk.TOOLBAR_BOTH)

    def add_button(label,on_click):
        # plain tool button appended at the end of the toolbar
        button = gtk.ToolButton(label=label)
        button.connect("clicked",on_click)
        toolbar.insert(button,-1)

    def add_menu_button(label,make_menu,on_select,on_click):
        # tool button with an attached drop-down menu; `on_select` is
        # handed to the menu factory, `on_click` handles the button itself
        button = gtk.MenuToolButton(None,label)
        menu = make_menu(on_select)
        button.set_menu(menu)
        menu.show_all()
        button.show_all()
        button.connect("clicked",on_click)
        toolbar.insert(button,-1)

    add_button("Similar",cmd_similar)
    # (a "Freq" button is not wired up; there is no cmd_freq handler)
    add_menu_button("Sort",sort_menu_populate,cmd_sel_sort,cmd_sort)
    # NOTE(review): the Classify button triggers cmd_nn and the Confidence
    # button triggers cmd_class -- confirm that this pairing is intended
    add_menu_button("Classify",classifier_menu,cmd_sel_class,cmd_nn)
    add_menu_button("Confidence",classifier_menu,cmd_sel_conf,cmd_class)
    add_button("Train",cmd_train)

def main():
    """Build the user interface from the glade file and run the GTK main loop.

    Sets the globals `main_widget_tree`, `class_selector`,
    `cluster_viewer` and `info_area`, connects the signal handlers,
    builds the toolbar, fills the class selector and the icon view and
    then blocks in gtk.main() until the window is closed.
    """
    global main_widget_tree,class_selector,cluster_viewer,info_area
    gladefile = ocrolib.findfile("gui/ocroex-cedit.glade")
    main_widget_tree = gtk.glade.XML(gladefile)
    dic = {
        "on_window1_destroy_event" : gtk.main_quit,
        "on_window1_delete_event" : gtk.main_quit,
        "on_iconview1_key_press_event" : on_iconview1_key_press_event,
        "on_iconview1_button_press_event" : on_iconview1_button_press_event,
        "on_iconview1_item_activated" : on_iconview1_item_activated,
        "on_iconview1_selection_changed" : on_iconview1_item_activated,
        "on_iconview1_motion_notify_event" : on_iconview1_motion_notify_event,
        }
    main_widget_tree.signal_autoconnect(dic)
    build_toolbar()

    graphview = main_widget_tree.get_widget("scrolledwindow1")
    cluster_viewer = main_widget_tree.get_widget("iconview1")
    # Check before the first use; placed after the set_* calls (as it was
    # previously) the assertion could never fire.
    assert cluster_viewer is not None,"iconview1 not found in "+gladefile
    cluster_viewer.set_selection_mode(gtk.SELECTION_MULTIPLE)
    class_selector = main_widget_tree.get_widget("comboboxentry1")
    info_area = main_widget_tree.get_widget("info")
    cluster_viewer.set_item_width(50)
    cluster_viewer.set_pixbuf_column(0)
    cluster_viewer.set_text_column(1)
    # one view column per grid column
    cluster_viewer.set_columns(cols)
    graphview.show_all()
    cluster_viewer.show_all()

    compute_combolist()
    # class_selector.set_active(0)

    main_widget_tree.get_widget("window1").show_all()
    set_store()
    gtk.main()

main()
